package server

import (
	"epg/conf"
	"epg/stats"
	"epg/utils"
	"fmt"
	"net"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pingcap/tidb/mysql"

	"github.com/zeast/logs"
	"context"
)

// nowFunc returns the current wall-clock time. It is a variable rather than
// a direct time.Now call, presumably so tests can substitute a fake clock.
var nowFunc = time.Now

// MysqlDB is a connection pool for a single MySQL backend node.
//
// Idle connections are kept in freeConn. The opener goroutine dials new
// connections when signalled on openCh and periodically checks idle ones;
// statsTask reports pool gauges. closeDB marks the pool stopped and closes
// the close channel, which opener and reConnectCheck watch.
type MysqlDB struct {
	// alias is the node alias from the config; used in log messages.
	alias string
	// mu guards the tunables rewritten by setAttr (cfg, maxLifetime,
	// maxIdle, maxOpen) and the stop flag written by closeDB.
	mu    sync.RWMutex
	// cfg is a copy of the node config passed to Open.
	cfg   conf.NodeConf

	// Backend address and credentials used by newConn to dial/authenticate.
	addr   string
	user   string
	passwd string
	//state  mysql.StatusFlag

	once        *sync.Once
	// stop is set (under mu) by closeDB; reConnectCheck checks it first.
	stop        bool
	// close is closed by closeDB to tell background goroutines to exit.
	close       chan struct{}
	// maxLifetime: idle connections older than this may be closed by
	// maybeRealeaseConn when more than maxIdle are open; 0 disables that.
	maxLifetime time.Duration
	// freeConn holds idle connections (buffer capacity dbMaxConn, see Open).
	freeConn    chan *mysqlConn
	// openCh carries requests for opener to dial connections (buffer
	// capacity maxOpen at the time Open ran).
	openCh      chan struct{}
	// 8 x uint64 = 64 bytes of padding, presumably to keep the int32
	// counters below on separate cache lines (false sharing) — confirm.
	_           [8]uint64
	// maxIdle is the idle-connection target; updated atomically by setAttr.
	maxIdle     int32
	_           [8]uint64
	// maxOpen caps the connections maybeOpenNewConn will open; updated
	// atomically by setAttr.
	maxOpen     int32
	_           [8]uint64
	// numOpen counts open connections, idle or in use (statsTask reports
	// numOpen - len(freeConn) as in-use); updated with sync/atomic.
	numOpen     int32
	// reConnect is the ticker driving reConnectCheck's retry loop.
	reConnect   *time.Ticker
	bad         bool
}

// Open creates a connection pool for the node described by nodeConf and
// starts its background goroutines (opener and statsTask). It then queues
// maxIdle/2 open requests on openCh so opener begins dialing connections.
func Open(nodeConf *conf.NodeConf) *MysqlDB {
	db := &MysqlDB{
		alias:       nodeConf.Alias,
		cfg:         *nodeConf,
		addr:        nodeConf.Addr,
		user:        nodeConf.User,
		passwd:      nodeConf.Passwd,
		once:        new(sync.Once),
		close:       make(chan struct{}),
		freeConn:    make(chan *mysqlConn, dbMaxConn),
		openCh:      make(chan struct{}, nodeConf.MaxConnNum),
		maxIdle:     nodeConf.MaxIdleConn,
		maxOpen:     nodeConf.MaxConnNum,
		maxLifetime: time.Duration(nodeConf.MaxLifeTime) * time.Second,
	}

	go db.opener()
	go db.statsTask()

	// Warm up: ask opener for half of the idle target.
	for remaining := db.maxIdle / 2; remaining > 0; remaining-- {
		db.openCh <- struct{}{}
	}
	return db
}

// setAttr updates one runtime-tunable pool setting, selected by key
// ("MaxLifeTime", "MaxIdleConn" or "MaxConnNum"), and mirrors the new value
// into m.cfg. Unknown keys are ignored. All writes happen under m.mu.
//
// Note: changing MaxConnNum does not resize openCh, whose buffer was sized
// from the initial maxOpen in Open.
func (m *MysqlDB) setAttr(key string, val int32) {
	m.mu.Lock()
	defer m.mu.Unlock()

	switch key {
	case "MaxIdleConn":
		atomic.StoreInt32(&m.maxIdle, val)
		m.cfg.MaxIdleConn = val
	case "MaxConnNum":
		atomic.StoreInt32(&m.maxOpen, val)
		m.cfg.MaxConnNum = val
	case "MaxLifeTime":
		m.cfg.MaxLifeTime = val
		m.maxLifetime = time.Duration(val) * time.Second
	}
}

// newConn dials m.addr over TCP (3s timeout), enables TCP keepalive, runs
// the MySQL handshake and authenticates with m.user/m.passwd. On success it
// increments numOpen and returns the ready connection. If any step after
// the dial fails, the connection is closed with mc.Close and the error is
// returned.
func (m *MysqlDB) newConn() (*mysqlConn, error) {
	mc := &mysqlConn{
		maxPacketAllowed: MaxPacketSize,
		maxWriteSize:     MaxPacketSize - 1,
		writeTimeout:     time.Duration(clientWriteTimeout),
		pushedAt:         nowFunc(),
		buf:              newBufPool(),
	}
	// The per-connection config carries only the address and credentials.
	mc.cfg = &conf.NodeConf{
		Addr:   m.addr,
		User:   m.user,
		Passwd: m.passwd,
	}
	mc.strict = false //mc.cfg.Strict

	var err error
	if mc.netConn, err = net.DialTimeout("tcp", mc.cfg.Addr, 3*time.Second); err != nil {
		return nil, err
	}

	// abort closes the half-initialised connection and reports cause.
	abort := func(cause error) (*mysqlConn, error) {
		mc.Close()
		return nil, cause
	}

	if tcp, isTCP := mc.netConn.(*net.TCPConn); isTCP {
		if err = tcp.SetKeepAlive(true); err != nil {
			// Don't send COM_QUIT before handshake.
			// NOTE(review): confirm mc.Close honours this.
			return abort(err)
		}
	}

	mc.io = newPacketIO(mc.netConn)
	mc.io.setReadTimeout(conf.GlobalConfig.Proxy.ReadTimeout)

	// Handshake: read the server greeting, send auth, read the result.
	cipher, err := mc.readInitPacket()
	if err != nil {
		return abort(err)
	}
	if err = mc.writeAuthPacket(cipher); err != nil {
		return abort(err)
	}
	if err = mc.readInitOK(); err != nil {
		return abort(err)
	}

	atomic.AddInt32(&m.numOpen, 1)
	return mc, nil
}

// opener is the pool's background worker. It loops until m.close is closed,
// handling three events:
//   - a request on openCh: if the backlog of pending requests exceeds
//     maxIdle the backlog is discarded; then maybeOpenNewConn is asked for
//     backlog+1 connections. A dial error (*net.OpError with Op "dial")
//     makes it call NodeMgr.moveOKDBBad for this node.
//   - every dbReleaseTick: maybeRealeaseConn checks one idle connection.
//   - m.close: take numOpen connections off freeConn and close them, then
//     return. This waits for in-use connections to be returned via putConn.
func (m *MysqlDB) opener() {
	defer utils.PrintPanicStack()
	t := time.After(dbReleaseTick)
	for {
		select {
		case <-m.openCh:
			pending := len(m.openCh)
			// maxIdle is stored atomically by setAttr; this goroutine does
			// not hold m.mu, so load it atomically to avoid a data race.
			if pending > 0 && int32(pending) > atomic.LoadInt32(&m.maxIdle) {
				// Only opener receives from openCh, so draining exactly
				// `pending` items cannot block.
				for x := 0; x < pending; x++ {
					<-m.openCh
				}
			}

			err := m.maybeOpenNewConn(int32(pending + 1))
			if e, ok := err.(*net.OpError); ok && e.Op == "dial" {
				// setAttr rewrites m.cfg under m.mu; pass a consistent copy.
				m.mu.RLock()
				cfg := m.cfg
				m.mu.RUnlock()
				NodeMgr.moveOKDBBad(cfg)
			}
		case <-t:
			m.maybeRealeaseConn()
			t = time.After(dbReleaseTick)
		case <-m.close:
			// putConn may decrement numOpen concurrently, so reload it
			// atomically on every iteration.
			for i := int32(0); i < atomic.LoadInt32(&m.numOpen); i++ {
				co := <-m.freeConn
				co.Close()
			}
			return
		}
	}
}

// reConnectCheck calls checkOnce once a second until it succeeds or m.close
// is closed. It returns immediately if closeDB has already run. The ticker
// is stored in m.reConnect and stopped when the function returns.
func (m *MysqlDB) reConnectCheck() {
	// stop is written by closeDB under m.mu; read it under the lock too.
	m.mu.RLock()
	stopped := m.stop
	m.mu.RUnlock()
	if stopped {
		return
	}

	m.reConnect = time.NewTicker(time.Second)
	defer m.reConnect.Stop()

	for {
		select {
		case <-m.reConnect.C:
			if m.checkOnce() {
				return
			}
		case <-m.close:
			return
		}
	}
}

// checkOnce asks maybeOpenNewConn for one connection. On error it logs a
// warning and returns false. Otherwise it calls NodeMgr.moveBadDBOK to put
// the node back in the normal DB list and returns true.
//
// Note: maybeOpenNewConn also returns nil when no connection needed to be
// opened; that counts as success here.
func (m *MysqlDB) checkOnce() bool {
	if err := m.maybeOpenNewConn(1); err != nil {
		logs.Warnf("重连失败. %s", m.alias) // reconnect failed
		return false
	}
	logs.Infof("重连成功. %s", m.alias) // reconnect succeeded

	// setAttr rewrites m.cfg under m.mu; pass a consistent copy.
	m.mu.RLock()
	cfg := m.cfg
	m.mu.RUnlock()

	// Move the node back into the normal DB list.
	NodeMgr.moveBadDBOK(cfg)
	return true
}

// waitFree blocks until every open connection is back in freeConn, i.e.
// until numOpen equals len(freeConn). It polls once per millisecond and has
// no timeout.
func (m *MysqlDB) waitFree() {
	// numOpen is modified with sync/atomic elsewhere, so load it atomically.
	for int(atomic.LoadInt32(&m.numOpen)) != len(m.freeConn) {
		// Actually pause between polls; creating a time.After channel
		// without receiving from it does not wait and spins a CPU.
		time.Sleep(time.Millisecond)
	}
}

// maybeOpenNewConn dials new backend connections and adds them to freeConn.
// The number attempted is the minimum of n, maxOpen-len(freeConn) and
// maxOpen-numOpen; nothing is done when that is <= 0. The first dial error
// is logged and returned.
//
// NOTE(review): the function returns after the first connection is pooled,
// so at most one connection is added per call even when n > 1. Confirm
// whether that is intended before changing it; opener's request handling
// depends on it.
func (m *MysqlDB) maybeOpenNewConn(n int32) error {
	m.mu.RLock()
	defer m.mu.RUnlock()

	numLeft := m.maxOpen - int32(len(m.freeConn))
	if numCan := m.maxOpen - atomic.LoadInt32(&m.numOpen); numLeft > numCan {
		numLeft = numCan
	}
	if numLeft > n {
		numLeft = n
	}

	for i := int32(0); i < numLeft; i++ {
		conn, err := m.newConn()
		if err != nil {
			logs.Errorf("新建 mysql 连接失败:%v, remote:%s", err, m.addr)
			return err
		}
		select {
		case m.freeConn <- conn:
			return nil
		default:
			logs.Error("空闲连接已满，程序正常情况下不应该运行到此处")
			// freeConn is full: the new connection cannot be pooled. Close
			// it and undo newConn's numOpen increment so it is not leaked,
			// and stop dialing connections that would hit the same wall.
			conn.Close()
			atomic.AddInt32(&m.numOpen, -1)
			return nil
		}
	}
	return nil
}

// maybeRealeaseConn checks one idle connection taken from freeConn:
//   - if maxLifetime > 0, more than maxIdle connections are open and the
//     connection has expired, it is closed (returns nil);
//   - otherwise it is pinged; on ping failure it is closed and the ping
//     error is returned;
//   - a healthy connection is put back into freeConn.
//
// Every close decrements numOpen. Returns nil when no idle connection is
// available.
func (m *MysqlDB) maybeRealeaseConn() error {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// Non-blocking receive. A len() check followed by a blocking receive
	// races with getConn taking the last idle connection, which would park
	// this goroutine (and hold m.mu) until some connection is returned.
	var mc *mysqlConn
	select {
	case mc = <-m.freeConn:
	default:
		return nil // no free conn to check
	}

	if m.maxLifetime > 0 && m.maxIdle < atomic.LoadInt32(&m.numOpen) && mc.expired(m.maxLifetime) {
		mc.Close()
		atomic.AddInt32(&m.numOpen, -1)
		logs.Debug("Connection closed by: out of maxLifetime")
		return nil
	}

	// A connection that fails ping is closed.
	if err := mc.ping(); err != nil {
		mc.Close()
		atomic.AddInt32(&m.numOpen, -1)
		logs.Debugf("Ping error: %v, %s", err, m.addr)
		return err
	}

	// Return it without blocking while m.mu is held; if freeConn filled up
	// in the meantime, drop the connection the same way putConn does.
	select {
	case m.freeConn <- mc:
	default:
		mc.Close()
		atomic.AddInt32(&m.numOpen, -1)
	}
	return nil
}

// closeDB marks the pool stopped and closes m.close, which opener and
// reConnectCheck select on. It is idempotent: later calls do nothing.
func (m *MysqlDB) closeDB() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if !m.stop {
		m.stop = true
		close(m.close)
	}
}

// getConn returns an idle connection from freeConn. If none is available
// immediately, it posts a non-blocking open request on openCh (dropped when
// openCh is full) and then waits for a connection until timeCtx is done.
// It returns nil on timeout.
func (m *MysqlDB) getConn(timeCtx context.Context) *mysqlConn {
	select {
	case mc := <-m.freeConn:
		return mc
	default:
	}

	// Nothing idle: ask opener for a new connection.
	select {
	case m.openCh <- struct{}{}:
	default:
		logs.Info("openCh is full, drop open signle")
	}
	logs.Info("openCh:", len(m.openCh), " db: ", m.alias)

	select {
	case mc := <-m.freeConn:
		return mc
	case <-timeCtx.Done():
		// Timed out waiting for a connection.
		logs.Infof("get connection %s wait timeout", m.alias)
		return nil
	}
}

// putConn returns mc to the pool.
//
// If mc.lastError is mysql.ErrBadConn, the connection is closed instead. It
// is also closed if freeConn has no room. In both cases numOpen is
// decremented. Otherwise pushedAt is stamped and mc goes back into
// freeConn. Always returns nil.
func (m *MysqlDB) putConn(mc *mysqlConn) error {
	if mc.lastError == mysql.ErrBadConn {
		mc.Close()
		atomic.AddInt32(&m.numOpen, -1)
		logs.Debugf("Close mysqlConn for lastErr is ErrBadConn")
		return nil
	}
	mc.pushedAt = nowFunc()
	select {
	case m.freeConn <- mc:
	default:
		// Per the original note, freeConn is sized larger than maxOpen, so
		// this is not expected to happen.
		mc.Close()
		// putConn does not hold m.mu, and setAttr stores maxIdle atomically,
		// so load it atomically to avoid a data race.
		logs.Debugf("Close mc freeConn full, maxIdle:%v,  %v", atomic.LoadInt32(&m.maxIdle), stack())
		atomic.AddInt32(&m.numOpen, -1)
	}
	return nil
}

// stack returns the current goroutine's stack trace, truncated to at most
// 2048 bytes.
func stack() string {
	trace := make([]byte, 2048)
	n := runtime.Stack(trace, false)
	return string(trace[:n])
}

// statsTask reports pool gauges for this node every 2*dbReleaseTick until
// m.close is closed:
//
//	node.<ip>.numOpen   - open connections (numOpen)
//	node.<ip>.freeConn  - idle connections available in freeConn
//	node.<ip>.chanOpen  - pending open requests in openCh
//	node.<ip>.inuseConn - numOpen minus free, clamped at 0
//
// <ip> is m.addr with ':' and '.' replaced by '_'.
func (m *MysqlDB) statsTask() {
	defer utils.PrintPanicStack()
	stater := stats.Stater
	ip := strings.NewReplacer(":", "_", ".", "_").Replace(m.addr)

	// Metric names depend only on the node address; build them once.
	numOpenKey := fmt.Sprintf("node.%s.numOpen", ip)
	freeConnKey := fmt.Sprintf("node.%s.freeConn", ip)
	chanOpenKey := fmt.Sprintf("node.%s.chanOpen", ip)
	inuseConnKey := fmt.Sprintf("node.%s.inuseConn", ip)

	ticker := time.NewTicker(dbReleaseTick * 2)
	defer ticker.Stop()
	for {
		select {
		case <-m.close:
			// Exit after closeDB instead of leaking this goroutine.
			return
		case <-ticker.C:
			numOpen := atomic.LoadInt32(&m.numOpen)
			free := int32(len(m.freeConn))
			// The two values are read separately. A connection opened between
			// the reads can make free exceed numOpen, so clamp to keep the
			// unsigned gauge from wrapping.
			inUse := numOpen - free
			if inUse < 0 {
				inUse = 0
			}
			stater.Gauge(numOpenKey, uint64(numOpen))        // open connections
			stater.Gauge(freeConnKey, uint64(free))          // idle, available connections
			stater.Gauge(chanOpenKey, uint64(len(m.openCh))) // pending open requests
			stater.Gauge(inuseConnKey, uint64(inUse))        // connections in use
		}
	}
}
